package org.funny.nn.som.demo;

import org.funny.nn.som.algorithm.Som;

import javax.swing.*;
import java.awt.*;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

/**
 * 用SOM解决TSP的范例，本例有动画效果，
 * 但是因为提炼 数据模型，以及动画效果，可能要有JAVA基础的人，才比较容易看懂。
 * 原始容易理解的的被放到old 包下。那个作为教学还是不错的模型。
 * @author LinLW
 *
 */
public class TspDemo extends JFrame {

    public TspDemo() {
        this.setTitle("SOM 解决TSP 演示");// 设置窗体的标题
        DrawingPanel panel = new DrawingPanel();
        this.getContentPane().add(panel);
        this.setSize(600, 600);// 设置窗体的大小
//        this.setLocation(260, 150);// 设置窗体的初始位置
        this.setResizable(false);
        this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        // this.setBounds(260, 150, 300, 300);//设置窗体的坐标、大小，相当于前面2行代码

        List<City> cities=readCities();

        float[][] inputMartix=new float[cities.size()][2];
        for(int i=0;i<cities.size();i++){
            City c=cities.get(i);
            inputMartix[i]=new float[]{(float)c.getX(), (float)c.getY()};
        }
        normalize(inputMartix);//归一化

        BlockingQueue<float[][]> queue=new LinkedBlockingQueue<>(1);
        //计算线程
        Thread t=new Thread(()->{
            Som som=new Som();
            TspDemo.this.som=som;
            som.setIterations(200); //设置运行200轮
            som.setDecayedRate(1.0F-(float) Math.E/200);//设置衰减系数，查多200轮的时候，邻居半径衰减到 1.0左右为宜
            som.setLearningRate(0.618F);//学习率，一开始的学习率不宜太高，否则竞争学习可能会过度。
            som.setNetworkSize(10*inputMartix.length);// 多一些神经网络节点计算准确率高一些。不建议小于输入的5倍。
            som.setNeighborRadix(inputMartix.length); //邻居半径
            som.setInput(inputMartix);

            //设置距离计算方式，地图距离。
            som.setDistanceCalculator((a, b)->(float)Math.sqrt((a[0]-b[0])*(a[0]-b[0])+(a[1]-b[1])*(a[1]-b[1])));

            //设置一个监视学习过程的钩子。
            som.setDataWatcher((iter,finish,network)-> {
                try {
                    if(iter<5||finish==0) {
                        //前五轮，我们把每个输入对图形的影响都拿去显示。后续的，每轮只显示初始状态。
                        queue.put((float[][])network);
                        TspDemo.this.iter=iter;
                        TspDemo.this.finish=finish;
                    }
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            });

            som.train4tsp();//开始学习。

            //输出顺序
            int[] route=som.getRoute();
            for(int idx:route){
                System.out.println(cities.get(idx).getCity());
            }
        });
        t.start();
        this.input=inputMartix;
        //控制动画刷新线程  两个线程用blocking queue来协调。
        Thread t2=new Thread(()->{
            while(true){
                try {
                    TspDemo.this.network=queue.poll(1000,TimeUnit.MILLISECONDS);
                    if(TspDemo.this.network==null){
                        break;
                    }
                    panel.repaint();
                    //机器计算太块，每秒控制只能显示10帧，否则，肉眼看不过来。
                    //初始场景停留1.5秒给肉眼反应时间
                    if(TspDemo.this.finish==0&& TspDemo.this.iter==0){
                        Thread.sleep(1500);
                    }
                    Thread.sleep(100);
                } catch (Exception e) {
                    e.printStackTrace();
                    break;
                }
            }
        });
        t2.start();
    }

    /* 几个用于计算线程 与 绘图线程件传值 */
    private float[][] input;
    private float[][] network;
    private Som som;
    private int iter;
    private int finish;

    public class DrawingPanel extends JPanel{
        public void paintComponent(Graphics g) {
            super.paintComponent(g);

            int[] route=som.getRoute();
            for(int i=0;i<route.length;i++){
                float[] dot1=input[route[i]];

                int dot2_idx=i+1;
                if(dot2_idx>=route.length){
                    dot2_idx=0;//最后一个节点划线到第一个节点，以便首尾相连
                }
                float[] dot2 =input[route[dot2_idx]];
                drawLine(dot1,dot2,Color.LIGHT_GRAY,1,g);//根据当前的神经网络层节点，推算旅行商路线
            }

            for(float[] dot:network){
                drawDot(dot, Color.GREEN,2, g); //绘制当前神经网络层节点
            }
            int i=0;
            for(float[] city:input){
                int width=3;
                if(finish>0&&++i==finish){
                    width=5;
                }
                drawDot(city, Color.RED,width, g); //绘制输入层节点(城市/不会变化)
            }
            g.setColor(Color.GRAY);
            g.drawString("第"+TspDemo.this.iter+"轮",520,40);

        }
        private void drawDot(float[] dot, Color color, int width,Graphics g) {
            int x=(int)(dot[0]*500)+10; //边框10像素不画内容。只画中间500像素
            int y=510-(int)(dot[1]*500);
            g.setColor(color);
            g.fillRect(x,y,width,width);
        }

        private void drawLine(float[] dot1,float[]dot2, Color color, int width,Graphics g) {
            int x=(int)(dot1[0]*500)+10; //边框10像素不画内容。只画中间500像素
            int y=510-(int)(dot1[1]*500);

            int tox=(int)(dot2[0]*500)+10;
            int toy=510-(int)(dot2[1]*500);
            g.setColor(color);
            g.drawLine(x,y,tox,toy);
        }
    }

    public static void main(String[] args){
        TspDemo demo=new TspDemo ();
        demo.setVisible(true);
    }
    public static List<City> readCities(){
        List<City> cities=new ArrayList<>();
        //固定提供一些基础数据即可
        cities.add(new City("沈阳市",123.429092,41.796768));
        cities.add(new City("长春市",125.324501,43.886841));
        cities.add(new City("哈尔滨市",126.642464,45.756966));
        cities.add(new City("北京市",116.405289,39.904987));
        cities.add(new City("天津市",117.190186,39.125595));
        cities.add(new City("呼和浩特市",111.751990,40.841490));
        cities.add(new City("银川市",106.232480,38.486440));
        cities.add(new City("太原市",112.549248,37.857014));
        cities.add(new City("石家庄市",114.502464,38.045475));
        cities.add(new City("济南市",117.000923,36.675808));
        cities.add(new City("郑州市",113.665413,34.757977));
        cities.add(new City("西安市",108.948021,34.263161));
        cities.add(new City("武汉市",114.298569,30.584354));
        cities.add(new City("南京市",118.76741,32.041546));
        cities.add(new City("合肥市",117.283043,31.861191));
        cities.add(new City("上海市",121.472641,31.231707));
        cities.add(new City("长沙市",112.982277,28.19409));
        cities.add(new City("南昌市",115.892151,28.676493));
        cities.add(new City("杭州市",120.15358,30.287458));
        cities.add(new City("福州市",119.306236,26.075302));
        cities.add(new City("广州市",113.28064,23.125177));
        cities.add(new City("台北市",121.5200760,25.0307240));
        cities.add(new City("海口市",110.199890,20.044220));
        cities.add(new City("南宁市",108.320007,22.82402));
        cities.add(new City("重庆市",106.504959,29.533155));
        cities.add(new City("昆明市",102.71225,25.040609));
        cities.add(new City("贵阳市",106.713478,26.578342));
        cities.add(new City("成都市",104.065735,30.659462));
        cities.add(new City("兰州市",103.834170,36.061380));
        cities.add(new City("西宁市",101.777820,36.617290));
        cities.add(new City("拉萨市",91.11450,29.644150));
        cities.add(new City("乌鲁木齐市",87.616880,43.826630));
        cities.add(new City("香港",114.165460,22.275340));
        cities.add(new City("澳门",113.549130,22.198750));


        return cities;
    }
    /**
     * 归一化，把城市的x和y 都换算成0-1之间数。最西边的X为0 最东边的X为1 同理最北边的Y为1最那边的Y为0
     * @param cities 所有城市列表
     */
    public static void normalize(float[][] cities){

        float maxX=Float.MIN_VALUE, maxY=Float.MIN_VALUE;
        float minX=Float.MAX_VALUE, minY=Float.MAX_VALUE;
        for(float[] p:cities){
            maxX=p[0]>maxX?p[0]:maxX;
            minX=p[0]<minX?p[0]:minX;
            maxY=p[1]>maxY?p[1]:maxY;
            minY=p[1]<minY?p[1]:minY;
        }
        for(float[] p :cities){
            p[0]=(p[0]-minX)/(maxX-minX);
            p[1]=(p[1]-minY)/(maxY-minY);
        }
    }





}